import os
import yaml
import numpy as np
import argparse


def logistic_D(n, k, n0):
    """
    Compute fractal dimension D(n) using logistic fit.

    D(n) = 1 + 2 / (1 + exp(k * (n - n0)))

    The curve runs from 3 (as k*(n - n0) -> -inf) to 1 (as k*(n - n0) -> +inf),
    with D(n0) = 2.  It is evaluated through the exact identity
    2 / (1 + exp(x)) = 1 - tanh(x / 2), i.e. D = 2 - tanh(k * (n - n0) / 2).
    This stays finite for any argument, so large flip counts no longer
    overflow ``np.exp`` and trigger a RuntimeWarning.

    Args:
        n: Flip count(s); scalar or NumPy array.
        k: Logistic steepness.
        n0: Logistic midpoint.

    Returns:
        D(n), broadcast to the shape of the inputs.
    """
    return 2 - np.tanh(k * (n - n0) / 2)


def linear_gD(D, a, b):
    """
    Return the pivot weight from the linear model g(D) = a*D + b.

    Args:
        D: Fractal dimension value(s); scalar or NumPy array.
        a: Slope of the linear model.
        b: Intercept of the linear model.
    """
    scaled = a * D
    return scaled + b


def _resolve_path(base_dir, path):
    """Return *path* unchanged if absolute, else joined onto *base_dir* and normalised."""
    if os.path.isabs(path):
        return path
    return os.path.normpath(os.path.join(base_dir, path))


def _default_kernel_path(data_dir, group):
    """Return the conventional kernel file inside *data_dir* for gauge *group*.

    Raises:
        ValueError: if *group* is not U1, SU2 or SU3 (case-insensitive).
    """
    defaults = {
        'U1': 'kernel.npy',
        'SU2': 'kernel_SU2.npy',
        'SU3': 'kernel_SU3.npy',
    }
    try:
        filename = defaults[group.upper()]
    except KeyError:
        raise ValueError(f'No kernel_path provided for gauge group {group}') from None
    return os.path.join(data_dir, filename)


def _apply_pivot_weights(K, weights, g_coupling):
    """Return ``g_coupling * weights * K`` with one weight per link (axis 0 of K).

    The weight for link i multiplies every entry of ``K[i]``, whatever the
    kernel's rank: scalar ``(L,)``, vector ``(L, M)``, or matrix ``(L, N, N)``.

    Raises:
        ValueError: if K is 0-d or its leading axis length differs from
            ``len(weights)``.
    """
    if K.ndim == 0 or K.shape[0] != weights.shape[0]:
        raise ValueError(
            f'kernel has shape {K.shape}; expected a leading axis of length '
            f'{weights.shape[0]} (one entry per link)'
        )
    # Append singleton axes so each link's weight broadcasts over the kernel's
    # trailing (component/matrix) axes rather than being aligned with them.
    w = weights.reshape(weights.shape + (1,) * (K.ndim - 1))
    return g_coupling * w * K


def main(config_path: str):
    """
    Compute the gauge potential A_μ(x) for each gauge group in the config.

    Steps:

    1. Compute one pivot weight g(D) per link from the link flip counts.
       This uses ``logistic_D`` and ``linear_gD`` with the ``pivot``
       settings (``a``, ``b``, ``logistic_k``, ``logistic_n0``).
    2. Load each group's kernel K and compute ``A_μ = g * g(D) * K``.
       ``g`` is the config's coupling constant (default 1.0).
    3. Save the result as ``Amu_<G>.npy`` in ``data_dir``.

    The weight for link i multiplies every entry of ``K[i]``. This covers
    scalar U(1) kernels of shape ``(num_links,)`` and matrix-valued SU(N)
    kernels such as ``(num_links, N, N)``.

    Relative paths in the config (``data_dir``, ``flip_counts_path``,
    ``kernel_path_<G>``) are resolved against the config file's directory.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        KeyError: if the ``pivot`` section or one of its keys is missing.
        ValueError: in any of these cases:
            - the flip counts do not have exactly one entry per link;
            - a kernel's leading axis does not match the number of links;
            - a gauge group other than U1/SU2/SU3 has no ``kernel_path_<G>``.
    """
    # Load configuration; an empty YAML document loads as None.
    with open(config_path, encoding='utf-8') as f:
        cfg = yaml.safe_load(f) or {}

    # Resolve data_dir relative to the directory of the config file
    base_dir = os.path.dirname(os.path.abspath(config_path))
    data_dir = _resolve_path(base_dir, cfg.get('data_dir') or 'data')
    os.makedirs(data_dir, exist_ok=True)

    # Only the number of links is needed from the lattice.
    # NOTE(review): allow_pickle=True can execute arbitrary code from a crafted
    # .npy file; only load lattice/kernel files from trusted sources.
    lattice = np.load(os.path.join(data_dir, 'lattice.npy'), allow_pickle=True)
    num_links = len(lattice)

    # Flip counts: one per link, loaded from file if configured, else all ones.
    fc_cfg = cfg.get('flip_counts_path')
    if fc_cfg:
        n = np.load(_resolve_path(base_dir, fc_cfg))
        if n.shape != (num_links,):
            raise ValueError(
                f'flip counts have shape {n.shape}; expected ({num_links},) '
                f'(one entry per link)'
            )
    else:
        n = np.ones(num_links)

    # Pivot parameters
    pivot = cfg['pivot']
    a = pivot['a']
    b = pivot['b']
    k = pivot['logistic_k']
    n0 = pivot['logistic_n0']

    # Coupling constant
    g_coupling = cfg.get('g', 1.0)

    # Per-link fractal dimension and pivot weight, shape (num_links,)
    D_vals = logistic_D(n, k, n0)
    gD_vals = linear_gD(D_vals, a, b)

    # A single group written as a bare string (e.g. `gauge_groups: U1`) would
    # otherwise be iterated character by character.
    gauge_groups = cfg.get('gauge_groups', ['U1'])
    if isinstance(gauge_groups, str):
        gauge_groups = [gauge_groups]

    for G in gauge_groups:
        # An explicit kernel_path_<G> wins; otherwise use the conventional
        # filename inside data_dir.
        kernel_path_cfg = cfg.get(f'kernel_path_{G}')
        if kernel_path_cfg:
            kernel_path = _resolve_path(base_dir, kernel_path_cfg)
        else:
            kernel_path = _default_kernel_path(data_dir, G)

        # Load kernel array (allow pickle for backward compatibility; see note above)
        K = np.load(kernel_path, allow_pickle=True)

        Amu = _apply_pivot_weights(K, gD_vals, g_coupling)

        # Save A_mu array to per-group file
        out_path = os.path.join(data_dir, f'Amu_{G}.npy')
        np.save(out_path, Amu)
        print(f'Computed A_mu for {G} on {num_links} links, saved to {out_path}')


if __name__ == '__main__':
    # Command-line entry point: only the config file location is configurable.
    cli = argparse.ArgumentParser(
        description='Compute A_mu from kernel and pivot parameters'
    )
    cli.add_argument(
        '--config',
        default='config.yaml',
        help='Path to config file',
    )
    main(cli.parse_args().config)